import numpy as np
from PDE1d import PDE1d
from ODE1d import ODE1d
from simulation_plots import make_plots
import tqdm
from scipy.special import xlogy

# Discretisation parameters shared by every simulation in this script.
# dz and N are handed to PDE1d (presumably z-grid spacing and number of grid
# cells -- confirm against PDE1d); dt is the time step used by both the ODE
# (ODE1d) and PDE (update_RK) updates.
dz, N, dt = 0.1, 60, 0.01


def f1(x,z):
    '''Cost function f1 at x and z (element-wise for NumPy arrays).

    Computes 1 - L, where L = 1/(1 + exp(-3(z - x))) is rounded to 16
    decimals; the rounding flushes values of L below ~5e-17 to exactly 0.
    '''
    denom = 1. + np.exp(-3. * (z - x))
    logistic = np.around(denom ** -1, decimals=16)
    return 1. - logistic


def grad_f1_x(x,z):
    '''Gradient of f1 with respect to x, evaluated at inputs x and z.

    Ignoring f1's 16-decimal rounding, f1(x, z) = e/(1 + e) with
    e = exp(3(x - z)), so d f1/dx = 3 e / (1 + e)**2.

    That expression is unchanged under e -> 1/e. It is therefore evaluated
    with t = exp(-3|z - x|) <= 1, which cannot overflow. The direct form
    overflows to inf for x - z > ~236 and then gives inf/inf = nan.
    Works element-wise on NumPy arrays.
    '''
    t = np.exp(-3. * np.abs(z - x))
    return 3 * t / (t + 1) ** 2


def f2(x,z):
    '''Cost function f2 at x and z (element-wise for NumPy arrays).

    Returns the logistic 1/(1 + exp(-3(z - x))) rounded to 16 decimals;
    the rounding flushes values below ~5e-17 to exactly 0. In this file
    f1 is the complement 1 - f2.
    '''
    logistic = (1. + np.exp(-3. * (z - x))) ** -1
    return np.around(logistic, decimals=16)


def grad_f2_x(x,z):
    '''Gradient of f2 with respect to x, evaluated at inputs x and z.

    Ignoring f2's 16-decimal rounding, f2(x, z) = 1/(1 + e) with
    e = exp(3(x - z)), so d f2/dx = -3 e / (1 + e)**2.

    That expression is unchanged under e -> 1/e. It is therefore evaluated
    with t = exp(-3|z - x|) <= 1, which cannot overflow. The direct form
    overflows to inf for x - z > ~236 and then gives inf/inf = nan.
    Works element-wise on NumPy arrays.
    '''
    t = np.exp(-3. * np.abs(z - x))
    return -3 * t / (t + 1) ** 2


def H1(rho_east_west,rho_bar,rho_tilde):
    '''KL-divergence-type term used as H_prime_rho by PDE1d in this script.

    Returns rho_east_west * (log(rho_bar) - log(rho_tilde)) / 10, computed
    with scipy's xlogy so that entries with rho_east_west == 0 contribute 0.

    rho_tilde is the initial condition; rho_bar is the current distribution.
    '''
    log_current = xlogy(rho_east_west, rho_bar)
    log_initial = xlogy(rho_east_west, rho_tilde)
    return (log_current - log_initial) / 10.


def V1(x,z):
    '''Potential V1 at x and z: identical to the cost function f1.'''
    potential = f1(x, z)
    return potential


def W_consensus(xj,xi,dx):
    '''Consensus interaction kernel between positions xj and xi, scaled by dx.

    Returns -0.05 / (1 + (xj - xi)**2) * dx: always negative for dx > 0,
    with the largest magnitude (0.05 * dx) when xj == xi. Element-wise for
    NumPy arrays.
    '''
    distance = np.abs(xj - xi)
    kernel = (1. + distance ** 2) ** -1
    return -0.05 * kernel * dx


def table_dist(z_i,mu,sig=0.2,factor=4.,verbose=True):
    '''Table-shaped probability density on the grid z_i, centred at mu.

    A Gaussian is cut in half at its peak and a flat plateau is inserted
    between the halves. With denom = sqrt(2*pi*sig)*factor and
    bnd = denom/2*(1 - 1/factor):
      * on |z - mu| < bnd the density is 1/denom;
      * outside, it is exp(-(|z - mu| - bnd)**2 / (2*sig)) / denom.
    The tails carry mass 1/factor and the plateau 1 - 1/factor, so the
    density integrates to 1 over the real line.

    Parameters
    ----------
    z_i : array_like
        Points at which to evaluate the density.
    mu : float
        Centre of the plateau.
    sig : float, optional
        Variance (not standard deviation) of the Gaussian tails. Default 0.2.
    factor : float, optional
        Width control: larger values widen and lower the plateau. Default 4.
    verbose : bool, optional
        If True (default), print the trapezoidal-rule area over z_i as a
        sanity check.

    Returns
    -------
    numpy.ndarray
        Density values with the same shape as z_i.
    '''
    # Accept lists/tuples as well as arrays (boolean masks need an ndarray).
    z_i = np.asarray(z_i, dtype=float)
    denom = np.sqrt(2*np.pi*sig) * factor
    bnd   = denom / 2 *(1.-1./factor)
    offset = z_i - mu
    z_uni_idx   = np.abs(offset) < bnd
    z_right_idx = offset >= bnd
    z_left_idx  = offset <= -bnd
    dist = np.zeros(np.shape(z_i))
    dist[z_uni_idx]   = 1./ denom
    dist[z_left_idx]  = np.exp(-(offset[z_left_idx]+bnd)**2/2./sig) / denom
    dist[z_right_idx] = np.exp(-(offset[z_right_idx]-bnd)**2/2./sig)/ denom
    if verbose:
        # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid;
        # fall back to it only on older NumPy versions.
        trapezoid = getattr(np, "trapezoid", None)
        if trapezoid is None:
            trapezoid = np.trapz
        print("area",trapezoid(dist,z_i))
    return dist




####################### select experiments ###############################
experiments = [1]

############ experiment 1: aligned goals, with consensus kernel ##############
# Selected with id 1; note the output prefix passed to make_plots below is
# "experiment4_consensus_kernel".
if 1 in experiments:
    print("Experiment 1")
    x0            = 2.
    T             = 40
    # round() instead of plain truncation: int(T/dt) silently drops a step
    # whenever the float quotient lands just below an integer
    # (e.g. int(0.3/0.1) == 2).
    nT            = int(round(T/dt))
    mu            = 3.
    x_conv_rate   = 1.e0 # slow convergence

    def aligned_V(x,z):
        '''Potential passed to PDE1d: the negative of V1.'''
        return -V1(x,z)

    pde = PDE1d(dz,N,nT,H_prime_rho=H1,V=aligned_V,W=W_consensus,save_data=False)
    pde.set_initial_distribution(lambda z: table_dist(z,1),lambda z: table_dist(z,mu))
    ode = ODE1d(pde.z_i,dt,nT,pde.g0,x0,f1,f2,grad_f1_x,grad_f2_x,save_data=False,x_speed=x_conv_rate)
    rho = pde.rho0
    # Alternate the coupled updates each step: x from the current rho,
    # then rho from the new x.
    for t in tqdm.tqdm(range(nT)):
        x   = ode.update_x(rho,t)
        rho = pde.update_RK(x,t,dt)

    make_plots(pde,ode,"plots/experiment4_consensus_kernel",make_gif=False,plot_kernel=False)

